
"""
    _counterfactual_save_string(p)

Return the file-name suffix encoding the model configuration in `p`.

The same suffix names the baseline `.jld2` file and the welfare CSV files, so runs
with identical switches and parameters read and write the same files.
"""
function _counterfactual_save_string(p)
    onoff = flag -> flag ? "on" : "off"
    a_s = onoff(p.amenities)
    m_s = onoff(p.realloc_workers)
    t_s = onoff(p.realloc_pollution)
    ma_s = onoff(p.realloc_production)
    ab_s = onoff(p.abatement)
    c_s = onoff(p.congestion)
    h_s = onoff(p.heterog)
    opt_s = p.opt ? "opt" : "base"
    vsl_s = p.ϵ_vsl > 0 ? "inc" : "cons"
    info_s = onoff(p.info)
    return "opt-$(opt_s)_prod-$(p.β_T)_amen-$(a_s)_abatement-$(ab_s)_transport-$(t_s)_mobility-$(m_s)_mktacc-$(ma_s)_cong-$(c_s)_heterog-$(h_s)_info-$(info_s)_theta-$(p.θ)_alpha-$(p.α)_gamma-$(p.γ)_iota-$(p.ι)_xi_mult-$(p.ξ_mult)_vsl-$(vsl_s)"
end

"""
    _insert_nonemployed_rows(x, fill_value)

Insert a nonemployed row after every (manufacturing, other) pair of rows of `x`.

`x` has two rows per location in each column. The result has three rows per
location, with the third row set to `fill_value`, and the same number of columns
as `x`.
"""
function _insert_nonemployed_rows(x, fill_value)
    by_location = reshape(x, 2, :)
    padded = vcat(by_location, fill(fill_value, 1, size(by_location, 2)))
    return reshape(padded, :, size(x, 2))
end

"""
    solve_counterfactual(p, data)

Solve the baseline model and the counterfactual, then write results to disk.

When `p.run_base` is true, the baseline is solved from `data` and saved to
`data/counterfactual/base_model_<suffix>.jld2`. The baseline solve covers
productivity, implicit emission prices and the equilibrium. When `p.run_base` is
false, that file must already exist.

The baseline is then reloaded and the counterfactual is solved. In both modes
productivity is scaled by `data.reg_cf_T`.
- `p.opt`: emission prices are iterated to marginal damages (`solve_md`), bounded
  below by the no-regulation prices `η ./ reg_cf_B`, until their largest relative
  change is at most `p.tol`.
- otherwise: with `p.abatement`, emission prices are divided by `reg_cf_B`.

When `!p.info`, the amenity change is recomputed from the solved equilibrium
(`solve_Bhats`). This matches what is done for the baseline.

With `p.ramp`, the counterfactual is saved as a new baseline and
`solve_ramp_standard` is run for each standard level. Otherwise welfare results
are written to `data/counterfactual/welfare_total_<suffix>.csv` (aggregates) and
`data/counterfactual/welfare_<suffix>.csv` (per location and industry).

Returns `nothing`.
"""
function solve_counterfactual(p, data)

    save_string = _counterfactual_save_string(p)
    base_model_path = "data/counterfactual/base_model_$(save_string).jld2"

    # regulation shifters are taken from the caller's data in every mode
    reg_cf_T = data.reg_cf_T
    reg_cf_B = data.reg_cf_B

    if p.run_base

        #################################
        # recover T_base, η_base, and market access
        #################################

        println("Step 1: Solving for productivity and implicit emission price.")

        T_base, η_base, P_base = solve_Tbase(p, data)
        B_hat = ones(size(data.L_total, 1))

        #################################
        # solve base model
        #################################

        println("Step 2: Solve for baseline equilibrium.")

        w_base, L_base, L_total_base, π_base, P_base, B_base = solve_w(p, data, data.w, T_base, η_base, B_hat)

        # solve_w only updates amenities when p.amenities && p.info,
        # so compute the implied amenity change after the fact otherwise
        if !p.info
            B_base = solve_Bhats(L_base, w_base, P_base, η_base, data, p)
        end

        # baseline emissions implied by solved wages, labor, and emission prices
        e_base = ((w_base .* L_base) ./ η_base) .* (p.ξ ./ (p.γ .* (1 .- sum(p.ξ))))'

        #################################
        # save baseline model
        #################################

        save(base_model_path,
            Dict(
                "π" => π_base,
                "τ" => data.τ,
                "e" => e_base,
                "damage_matrices" => data.d,
                "L_level" => convert(Array, data.L_level),
                "reg_cf_T" => reg_cf_T,
                "reg_cf_B" => reg_cf_B,
                "w_base" => w_base,
                "L_base" => L_base,
                "L_total_base" => L_total_base,
                "π_base" => π_base,
                "P_base" => P_base,
                "T_base" => T_base,
                "η_base" => η_base,
                "B_base" => B_base,
            ),
        )
    end

    #################################
    # load baseline model
    #################################

    base_model = load(base_model_path)

    data =
        (π = base_model["π_base"],
         d = base_model["damage_matrices"],
         e = base_model["e"],
         L = base_model["L_base"],
         L_total = base_model["L_total_base"],
         L_level = base_model["L_level"],
         w = base_model["w_base"],
         P = base_model["P_base"],
         τ = base_model["τ"],
         reg_cf_T = base_model["reg_cf_T"],
         reg_cf_B = base_model["reg_cf_B"])

    T_base = base_model["T_base"]
    η_base = base_model["η_base"]
    B_base = base_model["B_base"]
    w_base, L_base, L_total_base, π_base, P_base = data.w, data.L, data.L_total, data.π, data.P

    #################################
    # solve counterfactual model
    #################################

    println("Step 3: Solve for counterfactual equilibrium.")

    # undo negative productivity shock from nonattainment (both modes)
    T_c = T_base .* data.reg_cf_T
    η_c = copy(η_base)

    # initialize change in amenities
    B_hat = ones(size(data.L_total))

    if p.opt
        # running optimal counterfactual; saved output is actual versus optimal

        w_c = data.w

        # no regulation emission prices used as lower bound
        # this forces 'first-best' tax to be positive
        η_no_reg = η_c ./ data.reg_cf_B

        # current labor iterate; kept separate so the baseline values are never overwritten
        L_iter, L_total_iter = L_base, L_total_base

        # price emissions at marginal damage given labor, re-solve, repeat until prices converge
        max_error = Inf
        while max_error > p.tol

            η_c_old = η_c

            # cannot set prices lower than existing regulations
            η_c = max.(solve_md(L_total_iter, η_base, data, p), η_no_reg)

            w_c, L_iter, L_total_iter, π_c, P_c, B_hat = solve_w(p, data, w_c, T_c, η_c, B_hat, true, L_iter, L_total_iter)

            # Inf - Inf entries (non-emitting rows) give NaN; treat them as converged
            rel_err = abs.(η_c_old .- η_c) ./ η_c_old
            rel_err[isnan.(rel_err)] .= 0
            max_error = maximum(rel_err)
            println("Maximum log10 relative error in solving for optimal emissions prices $(round(log10(max_error), sigdigits = 4)).")
        end

        # solve one more time at the converged prices
        w_c, L_c, L_total_c, π_c, P_c, B_hat = solve_w(p, data, w_c, T_c, η_c, B_hat, true, L_iter, L_total_iter)

    else
        # running regular counterfactual of no nonattainment; output is actual versus no nonattainment

        if p.abatement
            # empirical estimates are negative of the actual parameter
            # so we need to divide here
            η_c = η_c ./ data.reg_cf_B
        end

        w_c, L_c, L_total_c, π_c, P_c, B_hat = solve_w(p, data, data.w, T_c, η_c, B_hat, true, L_base, L_total_base)
    end

    # as for the baseline: amenities are only solved inside solve_w when p.info
    if !p.info
        B_hat = solve_Bhats(L_c, w_c, P_c, η_c, data, p)
    end

    if p.ramp

        #################################
        # run counterfactual for single county shock
        #################################

        save("data/counterfactual/base_model_ramp_standard.jld2",
            Dict(
                "π" => π_c,
                "τ" => data.τ,
                "e" => w_c .* L_c ./ η_c .* (p.ξ ./ (p.γ .* (1 .- sum(p.ξ))))',
                "damage_matrices" => data.d,
                "reg_cf_T" => reg_cf_T,
                "reg_cf_B" => reg_cf_B,
                "w_base" => w_c,
                "L_total_base" => L_total_c,
                "L_base" => L_c,
                "π_base" => π_c,
                "P_base" => P_c,
                "T_base" => T_c,
                "η_base" => η_c,
            ),
        )

        levels = push!(collect(1:-0.01:0), 2022.0, Inf)
        foreach(level -> solve_ramp_standard(p, level), levels)

    else

        #################################
        # compute welfare results for saving
        #################################

        println("Step 4: Writing results.")

        # congestion: scale amenities by the change in county population
        if p.congestion
            L_c_county = reshape(repeat(sum(reshape(L_total_c', 3, :), dims = 1), 3), :, 1)
            L_base_county = reshape(repeat(sum(reshape(L_total_base', 3, :), dims = 1), 3), :, 1)
            B_hat = B_hat ./ (L_c_county ./ L_base_county) .^ p.ζ_c
        end

        # option value; replace 0/0 with 1
        π_hat = diag(π_base) ./ diag(π_c)
        π_hat[isnan.(π_hat)] .= 1

        # consumption price indices
        Pc_c = solve_P_cons(p, P_c)
        Pc_base = solve_P_cons(p, P_base)

        w_hat = w_base ./ w_c
        P_hat = Pc_base ./ Pc_c

        # no change in nonemployed consumption payoff
        v_hat = _insert_nonemployed_rows(w_hat ./ P_hat, 1.0)

        # factual relative to counterfactual; B_hat is counterfactual relative to factual
        welfare_hat = (v_hat ./ B_hat) ./ π_hat .^ (1 ./ p.ι)

        # population-weighted mean of `values` with weights `weights`
        popmean = (weights, values) -> weights' * values / sum(weights)

        Lm_base = L_total_base[1:3:end]
        Lo_base = L_total_base[2:3:end]
        Lu_base = L_total_base[3:3:end]
        Lm_c = L_total_c[1:3:end]
        Lo_c = L_total_c[2:3:end]
        Lu_c = L_total_c[3:3:end]

        # rows whose amenity shifter reg_cf_B differs from 1 (nonattainment), per location × 3 rows
        amenity_nonattain = repeat(vec(prod(data.reg_cf_B[1:2:end, :], dims = 2) .!= 1), inner = 3)
        amenity_attain = .!amenity_nonattain

        # NOTE(review): the *_total aggregates below are not divided by total population,
        # unlike the group means; they are means only if L_total sums to one — confirm

        # amenities welfare
        welfare_hat_amenity = 1 ./ B_hat
        welfare_amenity_attain = popmean(L_total_base[amenity_attain], welfare_hat_amenity[amenity_attain])
        welfare_amenity_nonattain = popmean(L_total_base[amenity_nonattain], welfare_hat_amenity[amenity_nonattain])
        welfare_amenity_total = L_total_base' * welfare_hat_amenity
        welfare_amenity_m = popmean(Lm_base, welfare_hat_amenity[1:3:end])
        welfare_amenity_o = popmean(Lo_base, welfare_hat_amenity[2:3:end])
        welfare_amenity_u = popmean(Lu_base, welfare_hat_amenity[3:3:end])

        # consumption welfare (no change for the nonemployed)
        welfare_consumption = _insert_nonemployed_rows((w_base ./ Pc_base) ./ (w_c ./ Pc_c), 1.0)
        welfare_consumption_total = L_total_base' * welfare_consumption
        welfare_consumption_m = popmean(Lm_base, welfare_consumption[1:3:end])
        welfare_consumption_o = popmean(Lo_base, welfare_consumption[2:3:end])
        welfare_consumption_u = popmean(Lu_base, welfare_consumption[3:3:end])
        welfare_cons_attain = popmean(L_total_base[amenity_attain], welfare_consumption[amenity_attain])
        welfare_cons_nonattain = popmean(L_total_base[amenity_nonattain], welfare_consumption[amenity_nonattain])

        # total welfare
        welfare_total = L_total_base' * welfare_hat
        welfare_m = popmean(Lm_base, welfare_hat[1:3:end])
        welfare_o = popmean(Lo_base, welfare_hat[2:3:end])
        welfare_u = popmean(Lu_base, welfare_hat[3:3:end])
        welfare_attain = popmean(L_total_base[amenity_attain], welfare_hat[amenity_attain])
        welfare_nonattain = popmean(L_total_base[amenity_nonattain], welfare_hat[amenity_nonattain])

        # emissions
        e_c = w_c .* L_c ./ η_c .* (p.ξ ./ (p.γ .* (1 .- sum(p.ξ))))'
        e_hat = data.e ./ e_c
        e_hat[isnan.(e_hat)] .= 1

        # employment change (level difference, base minus counterfactual)
        L_diff = L_total_base - L_total_c

        #################################
        # save results
        #################################

        # aggregate values
        welfare_df = DataFrame(
                                total = Any[], # total welfare
                                welfare_attain = Any[], welfare_nonattain = Any[],
                                amenity_total = Any[], amenity_attain = Any[], amenity_nonattain = Any[],
                                consumption_total = Any[], cons_attain = Any[], cons_nonattain = Any[],
                                manufacturing = Any[], other = Any[], unemp = Any[],   # industry-specific welfare
                                manu_cons = Any[], other_cons = Any[], unemp_cons = Any[],
                                manu_amenity = Any[], other_amenity = Any[], unemp_amenity = Any[],
                                m_emp_hat = Any[], o_emp_hat = Any[], u_emp_hat = Any[],   # industry-specific change in pop
                                m_emp_percent = Any[], o_emp_percent = Any[], u_emp_percent = Any[] # industry-specific percent change in pop
                                )
        push!(welfare_df, [
            welfare_total[1], welfare_attain[1], welfare_nonattain[1],
            welfare_amenity_total[1], welfare_amenity_attain[1], welfare_amenity_nonattain[1],
            welfare_consumption_total[1], welfare_cons_attain[1], welfare_cons_nonattain[1],
            welfare_m, welfare_o, welfare_u,
            welfare_consumption_m[1], welfare_consumption_o[1], welfare_consumption_u[1],
            welfare_amenity_m[1], welfare_amenity_o[1], welfare_amenity_u[1],
            sum(Lm_base - Lm_c), sum(Lo_base - Lo_c), sum(Lu_base - Lu_c),
            sum(Lm_base - Lm_c)/sum(Lm_base), sum(Lo_base - Lo_c)/sum(Lo_base), sum(Lu_base - Lu_c)/sum(Lu_base),
        ])

        CSV.write("data/counterfactual/welfare_total_" * save_string * ".csv", welfare_df)

        # market-specific values
        crosswalk = Array(CSV.read("data/simulation_fips_crosswalk.csv", DataFrame))

        fips = reshape([crosswalk crosswalk crosswalk]', :, 1)
        industry = repeat(["manufacturing"; "other"; "unemployed"], length(crosswalk))

        # bring in nonemployed rows: hats → 1, levels and prices → 0
        e_hat = _insert_nonemployed_rows(e_hat, 1.0)
        e_c = _insert_nonemployed_rows(e_c, 0.0)
        η_c = _insert_nonemployed_rows(η_c, 0.0)
        η_base = _insert_nonemployed_rows(η_base, 0.0)
        Pc_c = _insert_nonemployed_rows(Pc_c, 0.0)
        Pc_base = _insert_nonemployed_rows(Pc_base, 0.0)
        w_c = _insert_nonemployed_rows(w_c, 0.0)
        w_base = _insert_nonemployed_rows(w_base, 0.0)
        P_base = _insert_nonemployed_rows(P_base, 0.0)

        welfare_df = DataFrame(hcat(fips, industry, welfare_hat, L_diff,
                                    B_hat .^ (-1), e_hat[:, 1], e_hat[:, 2], e_hat[:, 3], e_hat[:, 4], e_hat[:, 5],
                                    e_c[:, 1], e_c[:, 2], e_c[:, 3], e_c[:, 4], e_c[:, 5],
                                    η_c[:, 1], η_c[:, 2], η_c[:, 3], η_c[:, 4], η_c[:, 5],
                                    η_base[:, 1], η_base[:, 2], η_base[:, 3], η_base[:, 4], η_base[:, 5],
                                    v_hat, Pc_base ./ Pc_c, w_base ./ w_c, L_total_base, L_total_c,
                                    w_base, w_c, Pc_base, Pc_c, B_base, P_base), :auto)

        rename!(welfare_df, [:fips, :industry, :welfare, :labor, :amenities,
                :emissions_nh3, :emissions_nox, :emissions_pm25, :emissions_so2, :emissions_voc,
                :cf_emissions_nh3, :cf_emissions_nox, :cf_emissions_pm25, :cf_emissions_so2, :cf_emissions_voc,
                :cf_emissions_price_nh3, :cf_emissions_price_nox, :cf_emissions_price_pm25, :cf_emissions_price_so2, :cf_emissions_price_voc,
                :base_emissions_price_nh3, :base_emissions_price_nox, :base_emissions_price_pm25, :base_emissions_price_so2, :base_emissions_price_voc,
                :real_wage, :price, :nominal_wage, :base_labor, :cf_labor, :base_wage, :cf_wage, :base_c_price, :cf_c_price, :base_amen, :base_ind_price])

        CSV.write("data/counterfactual/welfare_" * save_string * ".csv", welfare_df)

    end

    return nothing

end

"""
    solve_Tbase(p, data)

Recover the unobserved factual productivity from observed data.

1. Back out the implicit emission price `η_base` from observed wages, labor, and emissions.
2. Solve the market-access fixed point (`solve_ma`) for consumer and firm market access.
3. Invert the output equation at the observed data to obtain productivity `T_base`.

Returns `(T_base, η_base, P_base)`, where `P_base = cma .^ (-1 / p.θ)` is the
price index implied by consumer market access.
"""
function solve_Tbase(p, data)

    # implicit emission price: labor income per unit of emissions, scaled by the pollutant share
    emission_weights = (p.ξ ./ (p.γ .* (1 .- sum(p.ξ))))'
    η_base = (data.w .* data.L ./ data.e) .* emission_weights

    # emission elasticities: first row of each pair carries ξ, second row carries 0
    n_pollutants = length(p.ξ)
    xis = repeat(vcat(p.ξ', zeros(n_pollutants)'), length(data.L) ÷ 2)

    # an infinite implicit price means no emissions, so that elasticity is zeroed
    xis[isinf.(η_base)] .= 0

    cma, fma = solve_ma(data, p)

    # output equation inverted for productivity:
    # 1 - Σξ is the productive factor share; Π η^ξ is the emission part of the input cost
    labor_term = data.L .^ (1 - p.ζ_a * p.θ)
    wage_term = data.w .^ (1 + p.θ * p.γ * (1 - sum(p.ξ)))
    emission_term = prod(η_base .^ (xis * p.θ); dims = 2)
    T_base = (labor_term .* wage_term .* emission_term ./ fma)::Matrix{Float64}

    P_base = cma .^ (-1 / p.θ)
    return T_base, η_base, P_base

end

"""
    solve_Bhats(L_c, w_c, P_c, η_c, data, p)

Return the amenity change implied by pollution damages in a counterfactual,
relative to the baseline in `data`.

Counterfactual emissions are rebuilt from `w_c`, `L_c`, and `η_c`. Baseline
emissions are `data.e` with zero-share pollutants (`p.ξ .== 0`) set to zero.

For each receptor, damages are summed over pollutants and sources: all sources
when `p.realloc_pollution`, only its own location otherwise. The sum is divided
by real income and scaled by `(nominal wage / p.vsl_inc)^p.ϵ_vsl`, then turned
into `1 - damage share`.

The result is the counterfactual-to-baseline ratio of that quantity, as a column
with three rows per location. 0/0 is mapped to 1 and values are clamped to
`[0.01, 1.5]`.

`data` is not modified.
"""
function solve_Bhats(L_c, w_c, P_c, η_c, data, p)

    # only diagonal (own-location) damages if pollution is not reallocated
    d = p.realloc_pollution ? data.d : [Diagonal(dam) for dam in data.d]

    # no emissions for pollutants with zero share; copy so the caller's data.e
    # is not overwritten in place
    e_base = copy(data.e)
    e_base[:, p.ξ .== 0] .= 0

    # counterfactual emissions given the implicit emission price and labor income:
    # increasing in income, decreasing in price
    e_c = (( w_c .* L_c ) ./ η_c ) .* (p.ξ ./ (p.γ .* (1 .- sum(p.ξ))))'::Adjoint{Float64, Vector{Float64}}
    e_c[isnan.(e_c)] .= 0

    # append, per location, a nonemployed row equal to the labor-weighted average
    # of the two employed rows, then flatten to a column
    with_nonemployed(x2, L2) = reshape([x2; sum(x2 .* L2 ./ sum(L2, dims = 1), dims = 1)], :, 1)

    L_c2 = reshape(L_c, 2, :)::Matrix{Float64}
    L_base2 = reshape(data.L, 2, :)::Matrix{Float64}

    income_c = with_nonemployed(reshape(w_c ./ P_c, 2, :)::Matrix{Float64}, L_c2)
    nominal_wage_c = with_nonemployed(reshape(w_c, 2, :), L_c2)
    income_base = with_nonemployed(reshape(data.w ./ data.P, 2, :)::Matrix{Float64}, L_base2)
    nominal_wage_base = with_nonemployed(reshape(data.w, 2, :), L_base2)

    # total damages at each receptor: Σ_pollutants Σ_sources d[source, receptor] * emissions
    # (emitting rows are 1:2:end), repeated for the three rows of each location
    function receptor_damage(e)
        per_pollutant = vcat([sum(d[i] .* e[1:2:end, i], dims = 1) for i in eachindex(p.β_E)]...)
        total = sum(per_pollutant, dims = 1)
        return reshape((total' .* [1 1 1])', :, 1)
    end

    # consumption-equivalent terms
    receptor_damages_c = 1 .- receptor_damage(e_c) ./ income_c .* (nominal_wage_c / p.vsl_inc).^p.ϵ_vsl
    receptor_damages_base = 1 .- receptor_damage(e_base) ./ income_base .* (nominal_wage_base / p.vsl_inc).^p.ϵ_vsl

    # damage in consumption-equivalent terms in counterfactual relative to base
    amenity_change = (receptor_damages_c ./ receptor_damages_base)::Matrix{Float64}

    # 0/0 -> no change in amenities
    amenity_change[isnan.(amenity_change)] .= 1.

    # guard against bad early guesses in the outer algorithm
    amenity_change[amenity_change .<= .01] .= .01
    amenity_change[amenity_change .>= 1.5] .= 1.5

    return amenity_change

end

"""
    solve_w(p, data, w_c, T_c, η_c, B_hat, cf = false, L_base = data.L, L_total_base = data.L_total)

Solve for equilibrium wages. Each iteration updates market access, amenities,
migration shares, labor, and wages, and the loop stops when the largest relative
change in the real wage is at most `p.tol` (capped at 1e6 iterations).

# Arguments
- `w_c`: initial wage guess.
- `T_c`, `η_c`: productivity and emission prices, held fixed during the solve.
- `B_hat`: initial amenity change. It is replaced every iteration: by
  `solve_Bhats` when `p.amenities && p.info`, otherwise by ones.
- `cf`: counterfactual mode. When true:
  - market access stays at its initial value unless `p.realloc_production`;
  - labor stays at `L_base`/`L_total_base` unless `p.realloc_workers`.
- `L_base`, `L_total_base`: starting employed labor (two rows per location) and
  total labor (three rows per location).

Returns `(w_c, L_c, L_total_c, π_c, P_c, B_hat)`.
"""
function solve_w(p, data, w_c, T_c, η_c, B_hat, cf = false, L_base = data.L, L_total_base = data.L_total)

    # build emission elasticities for multi-industry / multi-pollutant
    xis = repeat([p.ξ'; 0*ones(length(p.ξ))'], length(data.L) ÷ 2)

    # if a manufacturing location is not emitting, then its emission elasticity is 0
    xis[isinf.(η_c)] .= 0

    P_c = data.P
    Pc_c = solve_P_cons(p, P_c)
    π_c = similar(data.π)
    B_hat_old = B_hat
    L_c = L_base
    L_old = L_base
    L_total_c = L_total_base
    error_out = Inf
    iters = 1

    # initial market access from baseline prices and expenditures; these values
    # stay fixed when production is not reallocated in a counterfactual
    expenditures_base = sum(reshape(data.w .* data.L ./ (p.γ*(1 - sum(p.ξ))), 2, :)', dims = 2) .* [p.α 1. - p.α]
    expenditures_base = reshape(expenditures_base', 1, :)'
    cma_c = P_c .^ (-p.θ)
    cma_c = cma_c / sqrt(sum(cma_c .^ 2))
    fma_c = vec((data.τ .^ (-p.θ)) * (cma_c .^ (-1) .* expenditures_base))
    fma_c = fma_c / sqrt(sum(fma_c .^ 2))

    # iterate on MA and wage equations until convergence
    while error_out > p.tol && iters < 1e6

        # Step 1: industry prices
        if !cf || p.realloc_production
            cma_c, fma_c = solve_ma(data, p, w_c, L_c, fma_c)
            P_c = @. cma_c ^ (-1 / p.θ)
        end

        # Step 2: amenities
        if p.amenities && p.info
            B_hat = solve_Bhats(L_c, w_c, P_c, η_c, data, p)
        else
            B_hat = ones(size(B_hat))
        end

        # Step 3: mobility shares and labor
        if !p.realloc_workers && cf

            # fix migration shares --> fix labor
            π_c = data.π
            L_total_c = L_total_base
            L_c = L_base

        else

            # get counterfactual migration shares
            π_c = solve_π(p, data, w_c, P_c, B_hat, L_total_c)

            # get counterfactual labor given initial labor distribution
            L_total_c = (π_c') * L_total_c
            L_c = reshape(reshape(L_total_c, 3, :)[1:2, :], :, 1)

        end

        # Step 4: wages
        w_new = (fma_c .* T_c ./ L_c.^(1 - p.ζ_a*p.θ) ./ prod((η_c).^(xis*p.θ), dims = 2)).^(1 / (1 .+ p.θ.*p.γ.*(1 .- sum(p.ξ))))

        # Step 5: error check and iterate until convergence
        Pc_new = solve_P_cons(p, P_c)
        error_out = maximum(abs.((w_new./Pc_new .- w_c./Pc_c)./(w_c./Pc_c)))
        # init = 0.0 keeps this from throwing when every entry is NaN/Inf (e.g. zero labor)
        error_out_L = maximum(filter(x -> !(isnan(x) || isinf(x)), abs.((L_c - L_old)./L_old)); init = 0.0)
        error_out_B = maximum(abs.((B_hat .- B_hat_old)./B_hat_old))

        println("Iter $(iters): Maximum log10 relative error in solving for real wage, amenities change, labor: ($(round(log10(error_out), sigdigits = 6)), $(round(log10(error_out_B), sigdigits = 6)), $(round(log10(error_out_L), sigdigits = 6))).")
        w_c = p.damp .* w_new .+ (1 - p.damp) .* w_c
        Pc_c = Pc_new
        B_hat_old = B_hat
        L_old = L_c

        iters += 1

    end

    # the iteration cap was hit before the real wage converged
    if error_out > p.tol
        @warn "solve_w: real wage did not converge" iters error_out
    end

    return w_c, L_c, L_total_c, π_c, P_c, B_hat

end


"""
    solve_ma(data, p, w_c = data.w, L_c = data.L, fma_temp = 1:n)

Solve for consumer and firm market access (up to normalization) by the
Fujimoto–Krause iteration.

Each step is
- `cma = τ^(-θ) * (income ./ fma)`, then
- `fma = τ^(-θ) * (expenditures ./ cma)`,

with both vectors normalized to unit Euclidean norm. Income is `w_c .* L_c`
divided by the labor share, and expenditures split each location's income by
`p.α` and `1 - p.α`.

Iteration stops when the change in `fma` is at most `p.tol`, or after 2000 steps
with a warning. `data.τ` is assumed symmetric.

Returns `(cma, fma)`.
"""
function solve_ma(data, p, w_c = data.w, L_c = data.L, fma_temp = convert(Vector{Float64}, collect(1:size(data.L)[1])))

    # loop-invariant trade-cost kernel
    τθ = data.τ .^ (-p.θ)
    unitnorm = x -> x / sqrt(sum(x .^ 2))

    fma_temp = unitnorm(fma_temp)
    income = w_c .* L_c ./ (p.γ*(1 - sum(p.ξ)))
    expenditures = sum(reshape(income, 2, :)', dims = 2) .* [p.α (1. - p.α)]
    expenditures = reshape(expenditures', 1, :)'

    # first step (τ is symmetric so transposes don't matter)
    cma_new = unitnorm(τθ * (fma_temp .^ (-1) .* income))
    fma_new = unitnorm(τθ * (cma_new .^ (-1) .* expenditures))
    fma_error = sqrt(sum((fma_new .- fma_temp) .^ 2))
    fma_temp = fma_new

    i = 1
    while fma_error > p.tol

        cma_new = unitnorm(τθ * (fma_temp .^ (-1) .* income))
        fma_new = unitnorm(τθ * (cma_new .^ (-1) .* expenditures))
        fma_error = sqrt(sum((fma_new .- fma_temp) .^ 2))

        println("Maximum log10 relative error in solving for FMA $(round(log10(fma_error), sigdigits = 3)).")

        fma_temp = fma_new

        i += 1
        if i == 2000
            @warn "solve_ma: market access did not converge" i fma_error
            break
        end

    end

    return cma_new, fma_new

end


"""
    solve_P_cons(p, P_industry)

Return the Cobb–Douglas consumption price index for each market.

`P_industry` holds two industry prices per market, weighted by `p.α` and `1 - p.α`.
The index is `Π_k (P_k / s_k)^{s_k}`, where `s_k` is the weight of industry `k`.

Workers in both industries of a market face the same index. The result is
therefore a single-column matrix with one row per entry of `P_industry`.
"""
function solve_P_cons(p, P_industry)
    weights = [p.α; 1-p.α]
    per_market = prod((reshape(P_industry, 2, :) ./ weights) .^ weights; dims = 1)
    return reshape(repeat(per_market, 2), :, 1)
end

"""
    solve_P(p, data, w_c, L_c, T_c, η_c, xis)

Return the industry price index, up to a normalization, implied by wages, labor,
productivity and emission prices.

Firm market access is backed out from the output equation. Consumer market
access is one application of the trade-cost kernel `data.τ .^ (-p.θ)` to
`income ./ fma`. It is normalized to unit Euclidean norm and mapped to prices via
`cma .^ (-1 / p.θ)`.
"""
function solve_P(p, data, w_c, L_c, T_c, η_c, xis)
    emission_term = prod(η_c .^ (xis * p.θ), dims = 2)
    fma = L_c .^ (1 - p.ζ_a * p.θ) .* w_c .^ (1 + p.θ .* p.γ * (1 - sum(p.ξ))) .* emission_term ./ T_c

    income = w_c .* L_c ./ (p.γ * (1 - sum(p.ξ)))
    trade_kernel = data.τ .^ (-p.θ)
    cma = vec(trade_kernel * (fma .^ (-1) .* income))
    cma_unit = cma / sqrt(sum(cma .^ 2))

    return cma_unit .^ (-1 / p.θ)
end

"""
    solve_π(p, data, w_c, P_c, B_hat, L_total_c)

Return counterfactual migration shares.

Each row `i` of the observed shares `data.π` is reweighted by the destination
payoff change `num_j = (v̂_j · B̂_j / L̂_j^ζ_c)^ι` and renormalized to sum to one,
so the unobserved migration costs embedded in `data.π` are held fixed.

- `v̂` is the real-wage change: nominal wage over consumption price index, and 1
  for the nonemployed (third) row of each location.
- `B̂` is the amenity change.
- `L̂` is the county population change when `p.congestion`, and 1 otherwise.
"""
function solve_π(p, data, w_c, P_c, B_hat, L_total_c)

    # real-wage change for the employed rows
    price_hat = solve_P_cons(p, P_c) ./ solve_P_cons(p, data.P)
    employed_v = (w_c ./ data.w) ./ price_hat

    # no change in nonemployed consumption payoff
    employed_by_loc = reshape(employed_v, 2, :)
    v_hat = reshape(vcat(employed_by_loc, ones(1, size(employed_by_loc, 2))), :, 1)

    # congestion: county population change, repeated for the three rows of each county
    L_hat = if p.congestion
        county_cf = reshape(repeat(sum(reshape(L_total_c', 3, :), dims = 1), 3), :, 1)
        county_base = reshape(repeat(sum(reshape(data.L_total', 3, :), dims = 1), 3), :, 1)
        county_cf ./ county_base
    else
        ones(size(L_total_c))
    end

    # reweight observed shares by destination payoff change and renormalize each row
    num = (v_hat .* B_hat ./ L_hat .^ p.ζ_c) .^ p.ι
    den = data.π * num
    return (num' .* data.π) ./ den

end

"""
    solve_md(L_total_base, η_base, data, p)

Return marginal damage per ton for each emitting row, given the labor distribution.

Population per location is the sum of its three labor rows. For each pollutant
damage matrix `dam`, the marginal damage of a source is
- `Σ_receptor dam[source, receptor] * pop[receptor]` when `p.realloc_pollution`;
- `dam[source, source] * pop[source]` otherwise.

The result has the shape of `η_base`. Rows `1:2:end` hold the damages; every other
entry, and every entry where `η_base` is infinite, is `Inf`.
"""
function solve_md(L_total_base, η_base, data, p)

    # location-specific population
    population = sum(reshape(L_total_base, 3, :)::Matrix{Float64}, dims = 1)
    population_col = population'

    per_pollutant = if p.realloc_pollution
        [dam * population_col for dam in data.d]
    else
        [diag(dam) .* population_col for dam in data.d]
    end
    damages = hcat(per_pollutant...)

    η_opt = fill(Inf, size(η_base))
    η_opt[1:2:end, :] .= damages
    η_opt[isinf.(η_base)] .= Inf

    return η_opt

end

# Turns a location × industry column (1-m, 1-o, 2-m, 2-o, ...) into a
# location × worker-type column (1-m, 1-o, 1-u, 2-m, ...) by inserting `val`
# in the slot of the third (nonemployed) type of every location.
function _pad_nonemployed(x, val)
    employed = reshape(x, 2, :)
    return reshape(vcat(employed, fill(val, 1, size(employed, 2))), :, 1)
end

# Re-solves the model after moving the air-quality standards and writes the
# welfare results.
#
# The equilibrium saved in `data/counterfactual/base_model_ramp_standard.jld2`
# is the starting point ("_c" variables). Some cells are placed in nonattainment:
# the manufacturing cells whose county violates the modified standard, plus every
# cell whose row of `reg_cf_B .- 1` already has a positive sum. Those cells get
# `T / maximum(reg_cf_T)` and `η .* maximum(reg_cf_B, dims = 1)`, and the
# equilibrium is re-solved with `solve_w` ("_base" variables).
#
# `percent` selects the counterfactual standard:
#   * `percent < 1`      every 1997 standard scaled by `percent` (tighter)
#   * `percent == 2022`  the `*_std_2022` columns (1997 values for o3_1h,
#                        so2_annual and so2_24h)
#   * `percent == Inf`   every manufacturing cell in nonattainment
#   * otherwise          only the cells already in nonattainment
#
# Outputs (h = "on"/"off" from `p.heterog`):
#   data/counterfactual/welfare_total_heterog-<h>_percent-<percent>.csv  aggregates
#   data/counterfactual/welfare_heterog-<h>_percent-<percent>.csv        one row per location × type
function solve_ramp_standard(p, percent::Real)

    #################################
    # load the saved counterfactual model as the new baseline
    #################################

    base_model = load("data/counterfactual/base_model_ramp_standard.jld2")

    data =
        (π = base_model["π_base"],
        d = base_model["damage_matrices"],
        e = base_model["e"],
        L = base_model["L_base"],
        L_total = base_model["L_total_base"],
        w = base_model["w_base"],
        P = base_model["P_base"],
        τ = base_model["τ"],
        reg_cf_T = base_model["reg_cf_T"],
        reg_cf_B = base_model["reg_cf_B"])

    T_c = base_model["T_base"]
    η_c = base_model["η_base"]
    w_c, L_c, L_total_c, π_c, P_c = data.w, data.L, data.L_total, data.π, data.P

    #################################
    # counterfactual nonattainment status
    #################################

    conc =
        CSV.read("data/concentrations_criteria_1997.csv", DataFrame)

    # one flag per location × industry cell; only manufacturing (odd) cells are
    # compared against the standard
    in_nonattainment = falses(length(T_c))

    if percent < 1.0
        # tighter standard: each 1997 standard scaled by `percent`
        in_nonattainment[1:2:end] = (
            conc.co_1h .> conc.co_1h_std*percent
            .|| conc.co_8h .> conc.co_8h_std*percent
            .|| conc.no2_annual .> conc.no2_annual_std*percent
            .|| conc.o3_1h .> conc.o3_1h_std*percent
            .|| conc.so2_annual .> conc.so2_annual_std*percent
            .|| conc.so2_24h .> conc.so2_24h_std*percent
            .|| conc.so2_3h .> conc.so2_3h_std*percent
            .|| conc.pm10_24h .> conc.pm10_24h_std*percent
            )
    elseif percent == 2022.0
        # 2022 standards; o3_1h, so2_annual and so2_24h use the 1997 columns
        in_nonattainment[1:2:end] = (
            conc.co_1h .> conc.co_1h_std_2022
            .|| conc.co_8h .> conc.co_8h_std_2022
            .|| conc.no2_1h .> conc.no2_1h_std_2022
            .|| conc.no2_annual .> conc.no2_annual_std_2022
            .|| conc.o3_1h .> conc.o3_1h_std
            .|| conc.o3_8h .> conc.o3_8h_std_2022
            .|| conc.so2_annual .> conc.so2_annual_std
            .|| conc.so2_24h .> conc.so2_24h_std
            .|| conc.so2_1h .> conc.so2_1h_std_2022
            .|| conc.so2_3h .> conc.so2_3h_std_2022
            .|| conc.pm10_24h .> conc.pm10_24h_std_2022
            )
    elseif percent == Inf
        # every manufacturing cell is placed in nonattainment
        in_nonattainment[1:2:end] .= true
    end
    # (any other `percent` adds no new cells)

    # cells already in nonattainment in the loaded baseline stay in nonattainment
    in_nonattainment = in_nonattainment .| vec(sum(data.reg_cf_B .- 1, dims = 2) .> 0)

    println("$(sum(in_nonattainment)) counties in nonattainment at $(100*percent)% of the factual standard.")

    #################################
    # re-solve the equilibrium under the counterfactual standard
    #################################

    # apply the shocks to copies so the loaded baseline arrays are not mutated
    T_base = copy(T_c)
    T_base[in_nonattainment] = T_c[in_nonattainment] ./ maximum(data.reg_cf_T)

    η_base = copy(η_c)
    η_base[in_nonattainment, :] = η_c[in_nonattainment, :] .* maximum(data.reg_cf_B, dims = 1)

    w_base, L_base, L_total_base, π_base, P_base, B_hat = solve_w(p, data, data.w, T_base, η_base, ones(size(L_total_c)), true, L_c, L_total_c)

    println("Step 4: Writing results.")

    #################################
    # welfare
    #################################

    # option value: change in the probability of staying in the same cell (0/0 -> 1)
    π_hat = diag(π_base) ./ diag(π_c)
    π_hat[isnan.(π_hat)] .= 1

    # consumption price indices before ("_c") and after ("_base") the change
    Pc_c = solve_P_cons(p, P_c)
    Pc_base = solve_P_cons(p, P_base)

    # real-wage change; the nonemployed have no consumption change
    v_hat = _pad_nonemployed((w_base ./ w_c) ./ (Pc_base ./ Pc_c), 1.0)

    # welfare change: real-wage change × amenity change B_hat (counterfactual
    # relative to the loaded baseline), divided by π_hat^(1/ι)
    welfare_hat = (v_hat .* B_hat) ./ π_hat .^ (1 ./ p.ι)

    # population weights by worker type (m, o, u rows of every location)
    Lm_base = L_total_base[1:3:end]
    Lo_base = L_total_base[2:3:end]
    Lu_base = L_total_base[3:3:end]

    # population-weighted average Σ_i L_i x_i / Σ_i L_i
    popavg(weights, values) = weights' * values / sum(weights)

    # locations whose manufacturing row of reg_cf_B has a product different from
    # one, repeated for the three rows of each location
    amenity_nonattain = repeat(vec(prod(data.reg_cf_B[1:2:end, :], dims = 2) .!= 1), inner = 3)
    amenity_attain = .!amenity_nonattain

    # amenity welfare
    welfare_hat_amenity = B_hat
    welfare_amenity_total = L_total_base' * welfare_hat_amenity
    welfare_amenity_attain = popavg(L_total_base[amenity_attain], welfare_hat_amenity[amenity_attain])
    welfare_amenity_nonattain = popavg(L_total_base[amenity_nonattain], welfare_hat_amenity[amenity_nonattain])
    welfare_amenity_m = popavg(Lm_base, welfare_hat_amenity[1:3:end])
    welfare_amenity_o = popavg(Lo_base, welfare_hat_amenity[2:3:end])
    welfare_amenity_u = popavg(Lu_base, welfare_hat_amenity[3:3:end])

    # consumption welfare
    welfare_consumption = _pad_nonemployed((w_base ./ Pc_base) ./ (w_c ./ Pc_c), 1.0)
    welfare_consumption_total = L_total_base' * welfare_consumption
    welfare_consumption_m = popavg(Lm_base, welfare_consumption[1:3:end])
    welfare_consumption_o = popavg(Lo_base, welfare_consumption[2:3:end])
    welfare_consumption_u = popavg(Lu_base, welfare_consumption[3:3:end])
    welfare_cons_attain = popavg(L_total_base[amenity_attain], welfare_consumption[amenity_attain])
    welfare_cons_nonattain = popavg(L_total_base[amenity_nonattain], welfare_consumption[amenity_nonattain])

    # total welfare
    welfare_total = L_total_base' * welfare_hat
    welfare_m = popavg(Lm_base, welfare_hat[1:3:end])
    welfare_o = popavg(Lo_base, welfare_hat[2:3:end])
    welfare_u = popavg(Lu_base, welfare_hat[3:3:end])
    welfare_attain = popavg(L_total_base[amenity_attain], welfare_hat[amenity_attain])
    welfare_nonattain = popavg(L_total_base[amenity_nonattain], welfare_hat[amenity_nonattain])

    #################################
    # save results
    #################################

    h_s = p.heterog ? "on" : "off"

    # aggregate values
    welfare_df = DataFrame(
        total = Any[], # total welfare
        welfare_attain = Any[], welfare_nonattain = Any[],
        amenity_total = Any[], amenity_attain = Any[], amenity_nonattain = Any[],
        consumption_total = Any[], cons_attain = Any[], cons_nonattain = Any[],
        manufacturing = Any[], other = Any[], unemp = Any[],   # industry-specific welfare
        manu_cons = Any[], other_cons = Any[], unemp_cons = Any[],
        manu_amenity = Any[], other_amenity = Any[], unemp_amenity = Any[], total_in_na = Any[]
        )
    push!(welfare_df, (
        welfare_total[1], welfare_attain, welfare_nonattain,
        welfare_amenity_total[1], welfare_amenity_attain, welfare_amenity_nonattain,
        welfare_consumption_total[1], welfare_cons_attain, welfare_cons_nonattain,
        welfare_m, welfare_o, welfare_u,
        welfare_consumption_m, welfare_consumption_o, welfare_consumption_u,
        welfare_amenity_m, welfare_amenity_o, welfare_amenity_u,
        sum(in_nonattainment)/2,
    ))

    CSV.write("data/counterfactual/welfare_total_heterog-$(h_s)_percent-$(percent).csv", welfare_df)

    # market-specific values; rows are location-major: (loc 1: m, o, u), (loc 2: m, o, u), ...
    crosswalk = Array(CSV.read("data/simulation_fips_crosswalk.csv", DataFrame))

    fips = reshape([crosswalk crosswalk crosswalk]', :, 1)
    industry = repeat(["manufacturing"; "other"; "unemployed"], length(crosswalk))

    # one status per location (from its manufacturing cell), repeated for its
    # three rows; `inner = 3` keeps the location-major order of `fips`
    # (stacking three copies column-wise and flattening would be location-minor)
    nonattain_flag = repeat(in_nonattainment[1:2:end], inner = 3)

    L_change = L_total_base - L_total_c

    market_table = hcat(
        fips, industry, welfare_hat, L_change, B_hat, v_hat,
        L_total_base, L_total_c,
        _pad_nonemployed(w_base, 0.0), _pad_nonemployed(w_c, 0.0),
        _pad_nonemployed(Pc_base, 0.0), _pad_nonemployed(Pc_c, 0.0),
        convert.(Int64, nonattain_flag),
    )
    welfare_df = DataFrame(market_table, :auto)

    rename!(welfare_df, [:fips, :industry, :welfare, :labor, :amenities, :real_wage, :base_labor, :cf_labor, :base_wage, :cf_wage, :base_c_price, :cf_c_price, :nonattain])

    CSV.write("data/counterfactual/welfare_heterog-$(h_s)_percent-$(percent).csv", welfare_df)

    return nothing
end

# Constructs the data named tuple used by the model for year `p.year`.
#
# Counties are restricted to those present in the simulation crosswalk, the AP3
# fips list and the trade-cost matrix. Location × industry vectors are ordered
# (1-m, 1-o, 2-m, 2-o, ...); vectors that include the nonemployed are ordered
# (1-m, 1-o, 1-u, 2-m, ...).
#
# Fields of the returned tuple:
#   π                   migration-share matrix (first CSV column dropped)
#   d                   AP3 damage matrices [NH3, NOX, PM25, SO2, VOC], scaled by
#                       `cpi` and divided by 1000 (1000s of dollars, matching wages)
#   L, L_total          employment without / with the nonemployed, as shares of
#                       total employment
#   L_level             employment levels without the nonemployed
#   cpi                 CPI of January `p.year` / CPI of January 2020
#   e                   emissions on manufacturing rows, zeros on "other" rows
#   w                   wages; P: price indices initialised to ones
#   τ                   trade costs expanded to location × industry pairs
#   reg_cf_T, reg_cf_B  counterfactual / baseline nonattainment shocks to
#                       productivity and to emission prices
function load_data(p)

    ########################
    ### load data
    ########################

    ### crosswalks

    # minimal set of counties
    crosswalk = Array(CSV.read("data/simulation_fips_crosswalk.csv", DataFrame))
    # ap3/nei counties; FIPS 12025 is recoded to 12086 (presumably the
    # Dade -> Miami-Dade renumbering)
    ap3_crosswalk = Array(CSV.read("raw/modified-ap3/fips.csv", DataFrame))
    ap3_crosswalk[ap3_crosswalk .== 12025] .= 12086
    # crosswalk for bea fips -> main fips
    mismatched_fips_crosswalk = CSV.read("data/mismatched-fips-crosswalk.csv", DataFrame)

    # Bool mask of the entries of `x` present in both `a` and `b`; a `missing`
    # comparison counts as "not present"
    in_both(x, a, b) = convert(Vector{Bool}, vec(coalesce.(in(a).(x) .* in(b).(x), false)))

    ### trade costs

    τ_df = Matrix(CSV.read("data/tradecosts/Tau$(p.year).csv", DataFrame, header = false))
    τ_fips_bea = convert(Vector{Int64}, τ_df[:, 1])
    τ_df = τ_df[:, 2:end]

    # map BEA fips -> main fips row by row so τ_crosswalk stays aligned with the
    # rows of τ_df (a DataFrames join does not guarantee the left row order);
    # codes without a non-missing match become 0
    bea_to_fips = Dict(zip(mismatched_fips_crosswalk.fips_bea, mismatched_fips_crosswalk.fips))
    τ_crosswalk = Int64[coalesce(get(bea_to_fips, f, missing), 0) for f in τ_fips_bea]

    τ_index = in_both(τ_crosswalk, crosswalk, ap3_crosswalk)
    τ_df = τ_df[τ_index, τ_index]

    ### damage transport

    ap3_index = in_both(ap3_crosswalk, crosswalk, τ_crosswalk)

    damage_matrices = map(["NH3", "NOX", "PM25", "SO2", "VOC"]) do pollutant
        CSV.read("data/modified-ap3/$(pollutant)_md_L_1997_prime_aged_for_counterfactual.csv", DataFrame) |>
            Matrix
    end

    # cpi deflator to discount damages back to the model year:
    # CPI of January p.year relative to CPI of January 2020
    cpi_df = CSV.read("raw/fred/CPIAUCSL.csv", DataFrame)
    is_january = Dates.month.(cpi_df.DATE) .== 1
    function january_cpi(yr)
        row = findfirst(is_january .& (Dates.year.(cpi_df.DATE) .== yr))
        row === nothing && throw(ArgumentError("no January $(yr) observation in raw/fred/CPIAUCSL.csv"))
        return cpi_df.CPIAUCSL[row]
    end
    cpi = january_cpi(p.year) / january_cpi(2020)

    # deflate, put into 1000s of dollars to match wages
    damage_matrices = [dam * cpi / 1000 for dam in damage_matrices]

    ### emissions

    emission_df = CSV.read("data/emissions_cf.csv", DataFrame, header = false) |>
        Matrix
    emission_df = emission_df[ap3_index, :]

    ####  migration shares

    π_df = Matrix(CSV.read("data/migration_shares_matrix_total.csv", DataFrame))
    π_df = π_df[:, 2:end]

    ### payroll and employment data (read once, split by industry)

    payroll_all = CSV.read("data/payroll_cf.csv", DataFrame)
    payroll_year = payroll_all[payroll_all.year .== p.year, :]
    payroll_df = payroll_year[payroll_year.industry .== "m", :]
    payroll_other_df = payroll_year[payroll_year.industry .== "o", :]
    payroll_unemp_df = payroll_year[payroll_year.industry .== "u", :]

    # counties in the crosswalk that also have trade costs; the mask is built on
    # the manufacturing rows and applied to all three industries, which assumes
    # they list counties in the same order (NOTE(review): confirm in payroll_cf.csv)
    county_index = in(crosswalk[in(τ_crosswalk).(crosswalk)]).(payroll_df.fips)

    payroll_df = payroll_df[county_index, :]
    payroll_other_df = payroll_other_df[county_index, :]
    payroll_unemp_df = payroll_unemp_df[county_index, :]

    ### nonattainment status (read once)

    nonattainment_status = CSV.read("data/greenbook/nonattainment_status.csv", DataFrame)

    # regulation in the model year
    regulation_df = nonattainment_status[nonattainment_status.year .== p.year, [:fips, :sum_nonattainment]]
    rename!(regulation_df, [:fips, :treated])

    # counterfactual regulation: counties listed for p.cf_year, all set to untreated
    regulation_cf_df = nonattainment_status[nonattainment_status.year .== p.cf_year, [:fips, :sum_nonattainment]]
    rename!(regulation_cf_df, [:fips, :treated])
    regulation_cf_df.treated .= 0

    sort!(regulation_df, [:fips])
    sort!(regulation_cf_df, [:fips])

    # counties in the nonattainment data that match the other datasets; the same
    # mask is applied to the counterfactual table, which assumes both years list
    # the same counties (NOTE(review): confirm)
    reg_index = in_both(regulation_df.fips, crosswalk, ap3_crosswalk)

    regulation_df = regulation_df[reg_index, :]
    regulation_cf_df = regulation_cf_df[reg_index, :]

    # baseline and counterfactual shocks to productivity and emissions price
    nonattainment_T =
        exp.(p.β_T .* convert(Array{Int64}, regulation_df.treated))
    nonattainment_B =
        exp.(p.β_E' .* convert(Array{Int64}, regulation_df.treated))

    nonattainment_cf_T =
        exp.(p.β_T .* convert(Array{Int64}, regulation_cf_df.treated))
    nonattainment_cf_B =
        exp.(p.β_E' .* convert(Array{Int64}, regulation_cf_df.treated))

    # match location-industry order of migration matrices (1-m, 1-o, 2-m, 2-o, ...);
    # the productivity shock hits manufacturing only, the emission-price shock both rows
    nonattainment_T = reshape([nonattainment_T ones(size(nonattainment_T))]', :, 1)
    nonattainment_cf_T = reshape([nonattainment_cf_T ones(size(nonattainment_cf_T))]', :, 1)

    nonattainment_B = hcat([reshape([nonattainment_B[:,i] nonattainment_B[:,i]]', :, 1) for i in eachindex(p.β_E)]...)
    nonattainment_cf_B = hcat([reshape([nonattainment_cf_B[:,i] nonattainment_cf_B[:,i]]', :, 1) for i in eachindex(p.β_E)]...)

    ########################
    ### data struct
    ########################

    # employment by location × industry, without and with the nonemployed
    L = reshape([payroll_df.employment payroll_other_df.employment]', :, 1)
    L_total = reshape([payroll_df.employment payroll_other_df.employment payroll_unemp_df.employment]', :, 1)

    data = (π = π_df,
            d = damage_matrices,
            L = L ./ sum(L_total),
            L_total = L_total ./ sum(L_total),
            L_level = L,
            cpi = cpi,
            e = kron(emission_df, [1; 0]),
            w = reshape([payroll_df.wage payroll_other_df.wage]', :, 1),
            P = ones(length(L)),
            τ = kron(τ_df, [1 1; 1 1]),
            reg_cf_T = nonattainment_cf_T./nonattainment_T,
            reg_cf_B = nonattainment_cf_B./nonattainment_B)

    if all(data.reg_cf_B .== 1)
        @warn "All regulation shocks are zero."
    end

    return data

end